/*
LEI Hongfaan
 * Copyright © 2020 LEI Hongfaan. Distributed under the MIT License.
 */
#nullable enable
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using System.Threading.Tasks;

//#if NET5_0
public class LEIHongfaan_Programme
{

    /// <summary>
    /// Computes the matrix product of <paramref name="first"/> and <paramref name="second"/>
    /// and returns it in a newly allocated array.
    /// </summary>
    /// <param name="first">Left operand; must be a zero-based two-dimensional array.</param>
    /// <param name="second">
    /// Right operand; must be a zero-based two-dimensional array whose row count equals the
    /// column count of <paramref name="first"/>.
    /// </param>
    /// <returns>
    /// A new array with as many rows as <paramref name="first"/> and as many columns as
    /// <paramref name="second"/>.
    /// </returns>
    /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">
    /// An argument has a non-zero lower bound, or the inner dimensions do not match.
    /// </exception>
    /// <remarks>
    /// As a side effect this method tries to raise the priority class of the current process
    /// to <see cref="ProcessPriorityClass.High"/>. The attempt is best effort: a refusal by the
    /// operating system does not prevent the multiplication from running.
    /// </remarks>
    public static unsafe double[,] Multiply(double[,] first, double[,] second)
    {
        if (first is null)
        {
            throw new ArgumentNullException(nameof(first));
        }
        if (0 != first.GetLowerBound(0) || 0 != first.GetLowerBound(1))
        {
            throw new ArgumentException("The array must have zero lower bounds.", nameof(first));
        }
        if (second is null)
        {
            throw new ArgumentNullException(nameof(second));
        }
        if (0 != second.GetLowerBound(0) || 0 != second.GetLowerBound(1))
        {
            throw new ArgumentException("The array must have zero lower bounds.", nameof(second));
        }

        if (first.GetLength(1) != second.GetLength(0))
        {
            throw new ArgumentException("The row count of the second matrix must equal the column count of the first matrix.", nameof(second));
        }

        try
        {
            // Process is IDisposable; release the handle as soon as the priority has been set.
            using var currentProcess = Process.GetCurrentProcess();
            currentProcess.PriorityClass = ProcessPriorityClass.High;
        }
        catch (Exception e) when (e is System.ComponentModel.Win32Exception || e is InvalidOperationException || e is NotSupportedException)
        {
            // Deliberately ignored: raising the priority is only a performance hint, and the
            // operating system may refuse it (for example for lack of privileges).
        }

        var result = new double[first.GetLength(0), second.GetLength(1)];
        Multiply(first, second, result);
        return result;
    }

    /// <summary>
    /// Adds elements 0 and 2 and elements 1 and 3 of each operand, producing
    /// <c>[first0 + first2, first1 + first3, second0 + second2, second1 + second3]</c>.
    /// </summary>
    /// <param name="first">Vector whose sums fill the lower two elements of the result.</param>
    /// <param name="second">Vector whose sums fill the upper two elements of the result.</param>
    /// <returns>The four sums described above.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector256<double> PermuteHorizontalAdd(Vector256<double> first, Vector256<double> second)
    {
        // Element order (0, 2, 1, 3): exchanges the two middle 64-bit elements.
        const byte SwapMiddleElements = 0B11011000;

        var firstShuffled = Avx2.Permute4x64(first, SwapMiddleElements);
        var secondShuffled = Avx2.Permute4x64(second, SwapMiddleElements);

        // HorizontalAdd interleaves its operands per 128-bit half, so the final shuffle
        // restores "first sums, then second sums" ordering.
        var interleavedSums = Avx.HorizontalAdd(firstShuffled, secondShuffled);
        return Avx2.Permute4x64(interleavedSums, SwapMiddleElements);
    }

    /// <summary>
    /// Computes one 4×4 tile of a matrix product: the dot products of four rows of
    /// <paramref name="first"/> with four rows of <paramref name="secondTransposed"/>.
    /// </summary>
    /// <param name="first">Start of the first of four consecutive rows of the left operand.</param>
    /// <param name="firstStride">Distance, in elements, between consecutive rows of <paramref name="first"/>.</param>
    /// <param name="secondTransposed">Start of the first of four consecutive rows of the transposed right operand.</param>
    /// <param name="secondTransposedStride">Distance, in elements, between consecutive rows of <paramref name="secondTransposed"/>.</param>
    /// <param name="count">Number of elements per row to accumulate; consumed four at a time.</param>
    /// <param name="result">Destination of the tile's top-left element.</param>
    /// <param name="resultStride">Distance, in elements, between consecutive rows of <paramref name="result"/>.</param>
    /// <remarks>
    /// All loads and stores are the aligned variants; debug builds assert that the pointers
    /// are aligned to the vector size and that strides and <paramref name="count"/> are
    /// multiples of four.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    private static unsafe void Multiply4xN_Nx4Transposed_Aligned(double* first, nint firstStride, double* secondTransposed, nint secondTransposedStride, nint count, double* result, nint resultStride)
    {
        Debug.Assert(IsNotNullDebug(first));
        Debug.Assert(IsNotNullDebug(secondTransposed));
        Debug.Assert(IsNotNullDebug(result));

        Debug.Assert(IsAligned(first, sizeof(Vector256<double>)));
        Debug.Assert(IsAligned(secondTransposed, sizeof(Vector256<double>)));
        Debug.Assert(IsAligned(result, sizeof(Vector256<double>)));

        Debug.Assert(IsMultipleOf(firstStride, 4));
        Debug.Assert(IsMultipleOf(secondTransposedStride, 4));
        Debug.Assert(IsMultipleOf(resultStride, 4));
        Debug.Assert(IsMultipleOf(count, 4));

        // Partial dot products of one row chunk with the four column chunks, returned as
        // [row·c0, row·c1, row·c2, row·c3].
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        static Vector256<double> DotWithColumns(Vector256<double> row, Vector256<double> c0, Vector256<double> c1, Vector256<double> c2, Vector256<double> c3)
        {
            var lowerPairs = Avx.HorizontalAdd(Avx.Multiply(row, c0), Avx.Multiply(row, c1));
            var upperPairs = Avx.HorizontalAdd(Avx.Multiply(row, c2), Avx.Multiply(row, c3));
            return PermuteHorizontalAdd(lowerPairs, upperPairs);
        }

        var pRow0 = first;
        var pRow1 = firstStride + first;
        var pRow2 = 2 * firstStride + first;
        var pRow3 = 3 * firstStride + first;

        var pCol0 = secondTransposed;
        var pCol1 = secondTransposedStride + secondTransposed;
        var pCol2 = 2 * secondTransposedStride + secondTransposed;
        var pCol3 = 3 * secondTransposedStride + secondTransposed;

        var sum0 = Vector256<double>.Zero;
        var sum1 = Vector256<double>.Zero;
        var sum2 = Vector256<double>.Zero;
        var sum3 = Vector256<double>.Zero;

        var chunkCount = count / 4;
        for (nint chunk = 0; chunk < chunkCount; ++chunk)
        {
            var offset = 4 * chunk;

            var col0 = Avx.LoadAlignedVector256(pCol0 + offset);
            var col1 = Avx.LoadAlignedVector256(pCol1 + offset);
            var col2 = Avx.LoadAlignedVector256(pCol2 + offset);
            var col3 = Avx.LoadAlignedVector256(pCol3 + offset);

            sum0 = Avx.Add(sum0, DotWithColumns(Avx.LoadAlignedVector256(pRow0 + offset), col0, col1, col2, col3));
            sum1 = Avx.Add(sum1, DotWithColumns(Avx.LoadAlignedVector256(pRow1 + offset), col0, col1, col2, col3));
            sum2 = Avx.Add(sum2, DotWithColumns(Avx.LoadAlignedVector256(pRow2 + offset), col0, col1, col2, col3));
            sum3 = Avx.Add(sum3, DotWithColumns(Avx.LoadAlignedVector256(pRow3 + offset), col0, col1, col2, col3));
        }

        Avx.StoreAligned(result, sum0);
        Avx.StoreAligned(resultStride + result, sum1);
        Avx.StoreAligned(2 * resultStride + result, sum2);
        Avx.StoreAligned(3 * resultStride + result, sum3);
    }

    /// <summary>
    /// Heuristic pointer sanity check used by debug assertions: reports whether
    /// <paramref name="ptr"/> is at or above the 1 MiB address mark.
    /// </summary>
    /// <param name="ptr">Pointer to examine.</param>
    /// <returns><c>true</c> when the address is at least 1024 * 1024; otherwise <c>false</c>.</returns>
    private static unsafe bool IsNotNullDebug(void* ptr) => (nuint)ptr >= 1024 * 1024;

    /// <summary>
    /// Reports whether <paramref name="ptr"/>, reinterpreted as an unsigned native integer,
    /// is a multiple of <paramref name="align"/>.
    /// </summary>
    /// <param name="ptr">Value (or address) to test.</param>
    /// <param name="align">
    /// Divisor; zero means "no requirement" and always succeeds. Debug builds assert that a
    /// non-zero value has exactly one bit set.
    /// </param>
    /// <returns><c>true</c> when the remainder is zero or <paramref name="align"/> is zero.</returns>
    private static bool IsMultipleOf(nint ptr, int align)
    {
        if (align != 0)
        {
            Debug.Assert(BitOperations.PopCount((uint)align) == 1);
            var remainder = (nuint)ptr % (uint)align;
            return remainder == 0;
        }

        return true;
    }

    /// <summary>
    /// Reports whether the address held by <paramref name="ptr"/> is a multiple of
    /// <paramref name="align"/>.
    /// </summary>
    /// <param name="ptr">Pointer whose address is tested.</param>
    /// <param name="align">Required alignment in bytes; forwarded unchanged to <c>IsMultipleOf</c>.</param>
    /// <returns>The result of <c>IsMultipleOf</c> for the pointer's address.</returns>
    private static unsafe bool IsAligned(void* ptr, int align)
    {
        var address = (nint)(nuint)ptr;
        return IsMultipleOf(address, align);
    }

    /// <summary>
    /// Copies a row-major matrix into a caller-supplied buffer at the next address that is a
    /// multiple of the <see cref="Vector256{T}"/> size, widening every row to a multiple of four
    /// elements and appending rows until the row count is a multiple of four. All added
    /// elements are zeroed.
    /// </summary>
    /// <param name="unalignedMatrixDataPtr">Source matrix data, rows stored back to back.</param>
    /// <param name="rowCount">Number of source rows.</param>
    /// <param name="columnCount">Number of source columns.</param>
    /// <param name="unalignedBufferPtr">Start of the destination buffer; need not be aligned.</param>
    /// <param name="bufferByteSize">Bytes available at <paramref name="unalignedBufferPtr"/>.</param>
    /// <param name="totalPaddedResultByteSize">
    /// Receives the number of buffer bytes consumed, counted from
    /// <paramref name="unalignedBufferPtr"/> (leading alignment gap plus padded matrix).
    /// </param>
    /// <returns>Pointer to the first element of the padded copy.</returns>
    /// <exception cref="OverflowException">
    /// A size computation overflows, or the buffer is smaller than the space required.
    /// </exception>
    private static unsafe double* AlignCopy(double* unalignedMatrixDataPtr, nuint rowCount, nuint columnCount, byte* unalignedBufferPtr, nuint bufferByteSize, out nuint totalPaddedResultByteSize)
    {
        // Span<byte> lengths are ints, so the trailing clear is issued in chunks of this size.
        const int ClearChunkByteSize = 0x4000000;

        // Evaluated only for its overflow check.
        _ = checked(columnCount * rowCount);

        var extraColumns = (nuint)(-(nint)columnCount) % 4;
        var extraRows = (nuint)(-(nint)rowCount) % 4;
        var leadingGap = (nuint)(-(nint)unalignedBufferPtr) % (uint)sizeof(Vector256<double>);
        var consumedByteSize = checked(sizeof(double) * (rowCount + extraRows) * (columnCount + extraColumns) + leadingGap);

        // Throws when the buffer cannot hold the padded copy.
        _ = checked(bufferByteSize - consumedByteSize);

        var alignedStart = unalignedBufferPtr + leadingGap;
        var bufferLimit = unalignedBufferPtr + bufferByteSize;
        var consumedLimit = unalignedBufferPtr + consumedByteSize;

        var rowByteSize = (long)(sizeof(double) * columnCount);
        var rowPaddingByteSize = sizeof(double) * (int)extraColumns;

        var source = unalignedMatrixDataPtr;
        var cursor = alignedStart;
        nuint row = 0;
        while (row < rowCount)
        {
            Buffer.MemoryCopy(source, cursor, bufferLimit - cursor, rowByteSize);
            source += columnCount;
            cursor += rowByteSize;

            new Span<byte>(cursor, rowPaddingByteSize).Clear();
            cursor += rowPaddingByteSize;

            ++row;
        }

        // Zero the appended rows.
        while (consumedLimit > cursor + ClearChunkByteSize)
        {
            new Span<byte>(cursor, ClearChunkByteSize).Clear();
            cursor += ClearChunkByteSize;
        }
        new Span<byte>(cursor, (int)(consumedLimit - cursor)).Clear();

        totalPaddedResultByteSize = consumedByteSize;
        return (double*)alignedStart;
    }

    /// <summary>
    /// Allocates one unmanaged block large enough for padded, aligned copies of every matrix in
    /// <paramref name="ms"/>, copies the matrices into it one after another, and records the
    /// start address of each copy.
    /// </summary>
    /// <param name="ms">Matrices to copy; enumerated once.</param>
    /// <param name="alignedPointers">
    /// Receives one address per matrix, in enumeration order. If copying fails part-way, the
    /// addresses added so far are left in the collection.
    /// </param>
    /// <param name="allocator">
    /// Allocation function taking a byte count; <see cref="Marshal.AllocHGlobal(IntPtr)"/> is used
    /// when omitted.
    /// </param>
    /// <returns>The allocated block and its size in bytes. The caller owns the block.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="ms"/> or <paramref name="alignedPointers"/> is <c>null</c>.
    /// </exception>
    /// <exception cref="ArgumentException"><paramref name="ms"/> contains a <c>null</c> matrix.</exception>
    /// <exception cref="OverflowException">The total size does not fit a native integer.</exception>
    /// <remarks>
    /// When the default allocator is in use and copying fails, the block is released before the
    /// exception propagates. A block obtained from a caller-supplied allocator cannot be released
    /// here because its matching deallocator is unknown.
    /// </remarks>
    private static unsafe (IntPtr HGrobal, nint ByteCount) AlignCopy(IEnumerable<double[,]> ms, ICollection<IntPtr> alignedPointers, Func<nint, IntPtr>? allocator = default)
    {
        if (ms is null)
        {
            throw new ArgumentNullException(nameof(ms));
        }
        // Validate before allocating so that a bad argument cannot leak the block.
        if (alignedPointers is null)
        {
            throw new ArgumentNullException(nameof(alignedPointers));
        }

        var usesDefaultAllocator = allocator is null;
        if (allocator is null)
        {
            allocator = Marshal.AllocHGlobal;
        }

        var matrices = ms.ToArray();
        var cb = (nint)0;
        foreach (var m in matrices)
        {
            if (m is null)
            {
                throw new ArgumentException("The sequence must not contain null matrices.", nameof(ms));
            }
            checked
            {
                cb += sizeof(double) * (nint)GetPaddedLength((nuint)m.GetLength(0), (nuint)m.GetLength(1));
            }
        }
        checked
        {
            // Worst-case gap needed to reach the first vector-aligned address.
            cb += sizeof(Vector256<double>) - 1;
        }

        var p = allocator(cb);
        try
        {
            var pBuffer = (byte*)p;
            var remainingByteCount = (nuint)cb;
            foreach (var m in matrices)
            {
                fixed (double* pValues = m)
                {
                    var s = (nint)(nuint)AlignCopy(pValues, (nuint)m.GetLength(0), (nuint)m.GetLength(1), pBuffer, remainingByteCount, out var consumedByteCount);
                    pBuffer += consumedByteCount;
                    checked
                    {
                        remainingByteCount -= consumedByteCount;
                    }
                    Debug.Assert(p <= s && s <= p + cb);
                    alignedPointers.Add(s);
                }
            }
        }
        catch when (usesDefaultAllocator)
        {
            // The caller never sees the pointer when we throw, so release it here.
            Marshal.FreeHGlobal(p);
            throw;
        }

        return (p, cb);
    }

    /// <summary>
    /// Returns the element count of a matrix after both of its dimensions have been rounded up
    /// to the next multiple of four.
    /// </summary>
    /// <param name="length0">Size of the first dimension.</param>
    /// <param name="length1">Size of the second dimension.</param>
    /// <returns>The product of the two rounded-up dimensions.</returns>
    /// <exception cref="OverflowException">The rounded-up sizes or their product overflow.</exception>
    private static unsafe nuint GetPaddedLength(nuint length0, nuint length1)
    {
        // (-n) mod 4 is the distance from n up to the next multiple of four.
        var missing0 = (nuint)(-(nint)length0) % 4;
        var missing1 = (nuint)(-(nint)length1) % 4;
        return checked((length0 + missing0) * (length1 + missing1));
    }

    /// <summary>
    /// Returns a new array holding the transpose of <paramref name="value"/>.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <param name="value">Zero-based two-dimensional array to transpose.</param>
    /// <returns>
    /// A new array whose element <c>[i, j]</c> equals <c>value[j, i]</c>; its dimensions are
    /// those of <paramref name="value"/> swapped.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="value"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException"><paramref name="value"/> has a non-zero lower bound.</exception>
    public static T[,] Transpose<T>(T[,] value)
    {
        if (value is null)
        {
            throw new ArgumentNullException(nameof(value));
        }
        if (0 != value.GetLowerBound(0) || 0 != value.GetLowerBound(1))
        {
            // The single-string constructor would treat the name as the message and leave
            // ParamName unset, so pass message and parameter name separately.
            throw new ArgumentException("The array must have zero lower bounds.", nameof(value));
        }

        var transposedRowCount = value.GetLongLength(1);
        var transposedColumnCount = value.GetLongLength(0);
        var transposed = new T[transposedRowCount, transposedColumnCount];
        for (long row = 0; row < transposedRowCount; ++row)
        {
            for (long column = 0; column < transposedColumnCount; ++column)
            {
                transposed[row, column] = value[column, row];
            }
        }
        return transposed;
    }

    /// <summary>
    /// Computes the matrix product of <paramref name="first"/> and <paramref name="second"/>
    /// and writes it into the top-left corner of <paramref name="result"/>.
    /// </summary>
    /// <param name="first">Left operand; must be a zero-based two-dimensional array.</param>
    /// <param name="second">
    /// Right operand; must be a zero-based two-dimensional array whose row count equals the
    /// column count of <paramref name="first"/>.
    /// </param>
    /// <param name="result">
    /// Destination; must be zero-based, with at least as many rows as <paramref name="first"/>
    /// and at least as many columns as <paramref name="second"/>. Elements outside the product
    /// area are left untouched.
    /// </param>
    /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">
    /// An argument has a non-zero lower bound, the inner dimensions do not match, or
    /// <paramref name="result"/> is too small.
    /// </exception>
    /// <remarks>
    /// The operands are copied into a temporary unmanaged buffer in which every matrix is
    /// padded to multiples of four in both dimensions and aligned for 256-bit vector access.
    /// The buffer is released before the method returns, including when an exception is thrown.
    /// </remarks>
    public static unsafe void Multiply(double[,] first, double[,] second, double[,] result)
    {
        if (first is null)
        {
            throw new ArgumentNullException(nameof(first));
        }
        if (0 != first.GetLowerBound(0) || 0 != first.GetLowerBound(1))
        {
            throw new ArgumentException("The array must have zero lower bounds.", nameof(first));
        }
        if (second is null)
        {
            throw new ArgumentNullException(nameof(second));
        }
        if (0 != second.GetLowerBound(0) || 0 != second.GetLowerBound(1))
        {
            throw new ArgumentException("The array must have zero lower bounds.", nameof(second));
        }
        if (result is null)
        {
            throw new ArgumentNullException(nameof(result));
        }
        if (0 != result.GetLowerBound(0) || 0 != result.GetLowerBound(1))
        {
            throw new ArgumentException("The array must have zero lower bounds.", nameof(result));
        }

        var productRowCount = first.GetLength(0);
        var innerCount = first.GetLength(1);
        var productColumnCount = second.GetLength(1);
        var resultColumnCount = result.GetLength(1);

        if (innerCount != second.GetLength(0))
        {
            throw new ArgumentException("The row count of the second matrix must equal the column count of the first matrix.", nameof(second));
        }
        if (productRowCount > result.GetLength(0))
        {
            throw new ArgumentException("The result matrix has fewer rows than the first matrix.", nameof(result));
        }
        if (productColumnCount > resultColumnCount)
        {
            throw new ArgumentException("The result matrix has fewer columns than the second matrix.", nameof(result));
        }

        // The kernel takes dot products of rows, so it consumes the right operand transposed.
        var secondTransposed = Transpose(second);

        var alignedPointers = new List<IntPtr>(3);
        IntPtr pUnmanagedHGlobalBuffer = default;
        try
        {
            (pUnmanagedHGlobalBuffer, _) = AlignCopy(new[] { first, secondTransposed, result }, alignedPointers);

            var paddedRowCount = (nuint)AlignUnsafe(productRowCount, 4);
            var paddedInnerCount = (nuint)AlignUnsafe(innerCount, 4);
            var paddedColumnCount = (nuint)AlignUnsafe(productColumnCount, 4);

            // The shadow copy of the result is laid out with the padded width of `result`
            // itself, which can exceed the padded product width when `result` is oversized.
            // The kernel must step through it with that same stride.
            var resultShadowStride = AlignUnsafe(resultColumnCount, 4);
            var resultAlignedShadow = (double*)alignedPointers[2];

            MultiplySecondTransposedAligned(
                (double*)alignedPointers[0], paddedInnerCount, paddedRowCount,
                (double*)alignedPointers[1], paddedInnerCount, paddedColumnCount,
                paddedInnerCount,
                resultAlignedShadow, (nuint)resultShadowStride);

            // Copy back only the product area. The kernel writes whole 4x4 tiles, so the shadow
            // also holds zeros just outside that area, and those must not reach `result`.
            var rowByteCount = (long)sizeof(double) * productColumnCount;
            if (0 < rowByteCount)
            {
                fixed (double* pResult = result)
                {
                    var pSource = resultAlignedShadow;
                    var pDestination = pResult;
                    for (var row = 0; row < productRowCount; ++row)
                    {
                        Buffer.MemoryCopy(pSource, pDestination, rowByteCount, rowByteCount);
                        pSource += resultShadowStride;
                        pDestination += resultColumnCount;
                    }
                }
            }
        }
        finally
        {
            // FreeHGlobal ignores a zero pointer, which is what remains if allocation failed.
            Marshal.FreeHGlobal(pUnmanagedHGlobalBuffer);
        }
    }

    /// <summary>
    /// Copies a padded matrix back into a managed array, dropping the padding at the end of
    /// each row.
    /// </summary>
    /// <param name="aligned">
    /// Source data whose rows are spaced by the column count of <paramref name="result"/>
    /// rounded up to a multiple of four.
    /// </param>
    /// <param name="result">Destination array; every one of its rows is overwritten.</param>
    private static unsafe void UnalignCopy(double* aligned, double[,] result)
    {
        var rowCount = result.GetLength(0);
        var columnCount = result.GetLength(1);
        var sourceStride = AlignUnsafe(columnCount, 4);
        var rowByteSize = sizeof(double) * columnCount;

        fixed (double* destinationStart = result)
        {
            var source = aligned;
            var destination = destinationStart;
            var row = 0;
            while (row < rowCount)
            {
                Buffer.MemoryCopy(source, destination, rowByteSize, rowByteSize);
                source += sourceStride;
                destination += columnCount;
                ++row;
            }
        }
    }

    /// <summary>
    /// Convenience overload for densely packed operands: forwards to the strided overload using
    /// the shared inner dimension as the row stride of both inputs and the column count of the
    /// second matrix as the row stride of the result.
    /// </summary>
    /// <param name="first">Left operand data.</param>
    /// <param name="firstRowCount">Row count of the left operand.</param>
    /// <param name="secondTransposed">Transposed right operand data.</param>
    /// <param name="secondColumnCount">Column count of the (untransposed) right operand.</param>
    /// <param name="firstColumnCount_secondRowCount">Shared inner dimension.</param>
    /// <param name="result">Destination data.</param>
    private static unsafe void MultiplySecondTransposedAligned(double* first, nuint firstRowCount, double* secondTransposed, nuint secondColumnCount, nuint firstColumnCount_secondRowCount, double* result)
    {
        MultiplySecondTransposedAligned(
            first: first,
            firstRowStride: firstColumnCount_secondRowCount,
            firstRowCount: firstRowCount,
            secondTransposed: secondTransposed,
            secondTransposedRowStride: firstColumnCount_secondRowCount,
            secondTransposedRowCount: secondColumnCount,
            columnCount: firstColumnCount_secondRowCount,
            result: result,
            resultRowStride: secondColumnCount);
    }

    /// <summary>
    /// Multiplies two padded matrices by handing each band of four rows of
    /// <paramref name="first"/> to <c>Multiply4xN_NxMTransposedAligned</c> via
    /// <see cref="Parallel.For(long, long, Action{long})"/>.
    /// </summary>
    /// <param name="first">Left operand data.</param>
    /// <param name="firstRowStride">Elements between consecutive rows of <paramref name="first"/>.</param>
    /// <param name="firstRowCount">Row count of <paramref name="first"/>; processed in groups of four.</param>
    /// <param name="secondTransposed">Transposed right operand data.</param>
    /// <param name="secondTransposedRowStride">Elements between consecutive rows of <paramref name="secondTransposed"/>.</param>
    /// <param name="secondTransposedRowCount">Row count of <paramref name="secondTransposed"/>.</param>
    /// <param name="columnCount">Number of elements per row to accumulate.</param>
    /// <param name="result">Destination data.</param>
    /// <param name="resultRowStride">Elements between consecutive rows of <paramref name="result"/>.</param>
    private static unsafe void MultiplySecondTransposedAligned(double* first, nuint firstRowStride, nuint firstRowCount, double* secondTransposed, nuint secondTransposedRowStride, nuint secondTransposedRowCount, nuint columnCount, double* result, nuint resultRowStride)
    {
        var bandCount = checked((long)(firstRowCount / 4));

        Parallel.For(0, bandCount, (band) =>
        {
            // Each band covers four consecutive rows of both the left operand and the result.
            var firstRowOfBand = (nuint)band * 4;
            var bandRows = firstRowStride * firstRowOfBand + first;
            var bandResult = resultRowStride * firstRowOfBand + result;
            Multiply4xN_NxMTransposedAligned(bandRows, firstRowStride, secondTransposed, secondTransposedRowStride, secondTransposedRowCount, columnCount, bandResult, resultRowStride);
        });
    }

    /// <summary>
    /// Multiplies a band of four rows of <paramref name="first"/> by the whole transposed right
    /// operand, producing four result rows one 4×4 tile at a time.
    /// </summary>
    /// <param name="first">Start of the four-row band of the left operand.</param>
    /// <param name="firstRowStride">Elements between consecutive rows of <paramref name="first"/>.</param>
    /// <param name="secondTransposed">Transposed right operand data.</param>
    /// <param name="secondTransposedRowStride">Elements between consecutive rows of <paramref name="secondTransposed"/>.</param>
    /// <param name="secondTransposedRowCount">Row count of <paramref name="secondTransposed"/>; processed in groups of four.</param>
    /// <param name="columnCount">Number of elements per row to accumulate.</param>
    /// <param name="result">Start of the four destination rows.</param>
    /// <param name="resultRowStride">Elements between consecutive rows of <paramref name="result"/>.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]
    private static unsafe void Multiply4xN_NxMTransposedAligned(double* first, nuint firstRowStride, double* secondTransposed, nuint secondTransposedRowStride, nuint secondTransposedRowCount, nuint columnCount, double* result, nuint resultRowStride)
    {
        // No scratch buffer is needed here: the tile kernel accumulates in registers and
        // stores straight to the destination.
        var tileCount = secondTransposedRowCount / 4;
        var pTileRows = secondTransposed;
        var pTileResult = result;
        for (nuint tile = 0; tile < tileCount; ++tile)
        {
            Multiply4xN_Nx4Transposed_Aligned(first, (nint)firstRowStride, pTileRows, (nint)secondTransposedRowStride, (nint)columnCount, pTileResult, (nint)resultRowStride);

            // Next tile: four columns to the right in the result, four rows down in the
            // transposed right operand.
            pTileResult += 4;
            pTileRows += 4 * secondTransposedRowStride;
        }
    }

    /// <summary>
    /// Rounds <paramref name="unaligned"/> up to the next multiple of <paramref name="align"/>.
    /// </summary>
    /// <param name="unaligned">Value to round up.</param>
    /// <param name="align">
    /// Alignment; zero returns the input unchanged. Debug builds assert that a non-zero value
    /// has exactly one bit set. No overflow check is performed.
    /// </param>
    /// <returns>The rounded-up value.</returns>
    private static nint AlignUnsafe(nint unaligned, int align)
    {
        if (align != 0)
        {
            Debug.Assert(BitOperations.PopCount((uint)align) == 1);

            // For a power-of-two alignment, (-x) & (align - 1) is the distance from x up to
            // the next multiple.
            nint mask = align - 1;
            var padding = (-unaligned) & mask;
            var aligned = unaligned + padding;

            Debug.Assert(IsMultipleOf(aligned, align));
            Debug.Assert(unaligned <= aligned);
            Debug.Assert(align > aligned - unaligned);
            return aligned;
        }

        return unaligned;
    }

    /// <summary>
    /// Rounds the address held by <paramref name="ptr"/> up to the next multiple of
    /// <paramref name="align"/> using the integer overload of <c>AlignUnsafe</c>.
    /// </summary>
    /// <param name="ptr">Pointer to round up.</param>
    /// <param name="align">Alignment in bytes; forwarded unchanged.</param>
    /// <returns>The rounded-up address as a pointer.</returns>
    private static unsafe void* AlignUnsafe(void* ptr, int align)
    {
        var address = (nint)(nuint)ptr;
        var alignedAddress = AlignUnsafe(address, align);
        return (void*)(nuint)alignedAddress;
    }
}
//#endif
